{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6adfa9fa-b760-49d8-be45-65fb67c5ab48",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Mixtral-8x7B AWQ rolling batch deployment guide with Inference Recommender\n",
    "In this tutorial, you will deploy a Large Model Inference (LMI) container from the AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.\n",
    "\n",
    "Make sure the following permission is granted before running the notebook:\n",
    "\n",
    "* SageMaker access\n",
    "\n",
    "## Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59dcc3aa-2cf6-44fc-95d8-d3fc819b5593",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45b3aab-7136-4c94-8c60-20a40191d08f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sagemaker\n",
    "from sagemaker.djl_inference.model import DJLModel\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b0f3f23-33ef-4a39-98fc-cbe03e30fcd6",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 2: Start building the SageMaker endpoint\n",
    "In this step, we will build a SageMaker endpoint from scratch.\n",
    "\n",
    "### Getting the container image URI (optional)\n",
    "\n",
    "Check out the available images: [Large Model Inference DLCs](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82e7b33e-87f5-4993-a762-ed13850f43b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose a specific version of LMI image directly:\n",
    "# image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a38d385-c4b2-4ad2-a669-06a49e5d0e55",
   "metadata": {},
   "source": [
    "### Create SageMaker model\n",
    "\n",
    "Here we use the [LMI PySDK](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to create the model.\n",
    "\n",
    "Check out more [configuration options](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/configurations.html#environment-variable-configurations)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b825f1b5-0d86-4f6b-91db-8ecfb361066c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model_id = \"TheBloke/mixtral-8x7b-v0.1-AWQ\"  # model will be downloaded from the Hugging Face Hub\n",
    "\n",
    "env = {\n",
    "    \"TENSOR_PARALLEL_DEGREE\": \"4\",         # requires 4 GPUs, set to \"max\" to use all GPUs on the instance\n",
    "    \"OPTION_ROLLING_BATCH\": \"lmi-dist\",    # enable rollingbatch with lmi-dist\n",
    "    \"OPTION_MAX_MODEL_LEN\": \"25456\",\n",
    "}\n",
    "\n",
    "model = DJLModel(\n",
    "            model_id=model_id,\n",
    "            env=env,\n",
    "            role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e6e5f3b-0450-4365-b67b-d2a9f2cfcac4",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Create SageMaker endpoint\n",
    "\n",
    "You need to specify the instance type and the endpoint name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "765de906-9747-4d69-aa78-0b0a359bd57f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "instance_type = \"ml.g5.12xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "predictor = model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             # container_startup_health_check_timeout=3600,\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e60c73-f3d8-43ef-9261-23677f03d5cb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 3: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21c0102-132c-4c18-8627-223592578c86",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": \"The future of Artificial Intelligence is\", \"parameters\": {\"max_new_tokens\": 128, \"do_sample\": True}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ed2e9a6-ef76-4a35-ae97-23dafb8709ce",
   "metadata": {},
   "source": [
    "## Step 4: Benchmark with Inference Recommender\n",
    "\n",
    "Run a SageMaker Inference Recommender job against the Mixtral endpoint deployed with LMI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "818428d5-86cf-4335-96b4-b3244e68ea2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "payload = {\"inputs\": \"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.:\\n\\nAlso, with information from a knowledge library, without any retraining, see retrieval augmented generation. Prompt engineering is the process of designing and refining the prompts or input for a large model to generate specific types of output. Prompt engineering involves selecting appropriate keywords, providing context, and shaping the input in a way that encourages the model to produce the desired response and is the vital technique to actively shape the behavior and output of foundation models. Effective prompt engineering is crucial for directing model behavior and achieving desired responses. Through prompt engineering, you can control a model's tone, style, and domain expertise without more involved customization measures like fine-tuning. The goal is to provide sufficient context and guidance to the model so that it can generalize and perform well on unseen or limited data scenarios. Fine-tuning a pre-trained foundation model is an affordable way to take advantage of their broad capabilities while customizing a model on your own small corpus. Fine-tuning is the customization method that involved further training and does change the weights of your model. Fine-tuning might be useful.\\n\\ncustomization method that involved further training and does change the weights of your model. Fine tuning might be useful.\\n\\nyou if you need to customize your model to specific business needs, your model to successfully work with domain-specific language such as industry jargon, technical terms, or other specialized vocabulary. Enhanced performance for specific tasks. Accurate, relative, and context-aware responses in applications. Responses that are more factual, less toxic, and better aligned to specific requirement. 
There are two main approaches that you can take for fine-tuning depending on your use case and chosen foundation model. If you are interested in fine-tuning your model on domain-specific data, see domain adaptation fine-tuning. If you are interested in instruction-based fine-tuning, using prompt and response examples, see instruction-based fine-tuning. Retrieval Augmented generation. Foundation models are usually trained offline, making the model agnostic to any data that is created after the model was trained. Additionally, foundation models are trained on very general domain corpora, making them less effective for domain-specific tasks. You can use Retrieval Augmented Generation, or RAG, to retrieve data from outside of foundation model and augment your prompts by adding the relevant retrieval data in context. For more information about RAG model architectures, see Retrieval Augmented Generation for Knowledge-Intensive NLP Tasks. With RAG, the external data used to augment your prompts can come from multiple data source, such as.\\n\\nAmazon SageMaker is a fully managed machine learning service. With SageMaker, data scientists and developers can quickly and easily build and train machine learning models, and then directly deploy them into a production-ready hosted environment. It provides an integrated Jupyter authoring notebook instance for easy access for your data source for exploration and analysis. So you don't have to manage servers. It also provides common machine learning algorithms that are optimized to run efficiently against extremely large data in a distributed environment. With native support for bring your own algorithms and frameworks, SageMaker offers flexible distributed training options that adjust to your specific workloads. Deploy a model into a secure and scalable environment by launching it with a few clicks from SageMaker Studio or the SageMaker Console. Nation models are extremely powerful models able to solve a wide array of tasks. 
To solve most tasks efficiently, these models require some form of customization. The recommended way to first customize a foundation model to a specific use case is through prompt engineering. Providing your foundation model with well-engineered, context-rich prompts can help achieve desired results without any fine-tuning or changing of model weights. For more information, see prompt engineering for foundation models.\\n\\nIf prompt engineering alone is not enough to customize your foundation model to a specific task, you can fine-tune the foundation model on additional domain-specific data. The fine-tuning process involves changing model weights. If you want to customize your model\\n\\nQuestion: how to prompt engineer\\nHelpful Answer:\",\"parameters\": {\"max_new_tokens\": 1200, \"do_sample\":True}}\n",
    "\n",
    "with open('payload2.json', 'w') as f:\n",
    "    json.dump(payload, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac650915-bdb1-4af0-a077-ed435437a766",
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar -czvf payload2.tar.gz payload2.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ddf2c38-8423-4219-b5f2-93cd0ea58d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "bucket = session.default_bucket()  # bucket to house artifacts\n",
    "s3_location = f\"s3://{bucket}/sagemaker/InferenceRecommender/djl-inference\"\n",
    "payload_tar_url = sagemaker.s3.S3Uploader.upload(\"payload2.tar.gz\", s3_location)\n",
    "\n",
    "payload_tar_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90348e37-8e68-453b-815b-04e1c8fbbc88",
   "metadata": {},
   "outputs": [],
   "source": [
    "mixtral_endpoint_name = predictor.endpoint_name\n",
    "mixtral_model_name = model.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d375da2-2641-4566-b1b7-1fe0e8f84b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time \n",
    "import boto3\n",
    "\n",
    "sm_client = boto3.client('sagemaker')\n",
    "\n",
    "job_name = \"Mixtral-8x7B-v01-awq-\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n",
    "response = sm_client.create_inference_recommendations_job(\n",
    "    JobName=job_name,\n",
    "    JobType='Default',\n",
    "    RoleArn=role,\n",
    "    InputConfig={\n",
    "        'ContainerConfig': {\n",
    "            'Domain': 'NATURAL_LANGUAGE_PROCESSING',\n",
    "            'Task': 'TEXT_GENERATION',\n",
    "            'PayloadConfig': {\n",
    "                'SamplePayloadUrl': payload_tar_url,\n",
    "                'SupportedContentTypes': [\"application/json\"],\n",
    "            },\n",
    "            # specify the instance types you would like to test\n",
    "            'SupportedInstanceTypes': ['ml.g5.12xlarge'],\n",
    "            'SupportedEndpointType': 'RealTime'\n",
    "        },\n",
    "        'ModelName': mixtral_model_name,\n",
    "        'Endpoints': [\n",
    "            {\n",
    "                'EndpointName': mixtral_endpoint_name\n",
    "            }\n",
    "        ],\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2b745a6-c445-4d8e-8c21-ba34d8983249",
   "metadata": {},
   "outputs": [],
   "source": [
    "describe_IR_job_response = sm_client.describe_inference_recommendations_job(JobName=job_name)\n",
    "\n",
    "while describe_IR_job_response[\"Status\"] in [\"IN_PROGRESS\", \"PENDING\"]:\n",
    "    describe_IR_job_response = sm_client.describe_inference_recommendations_job(JobName=job_name)\n",
    "    print(describe_IR_job_response[\"Status\"])\n",
    "    time.sleep(30)\n",
    "\n",
    "print(f'Inference Recommender job {job_name} has finished with status {describe_IR_job_response[\"Status\"]}.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00ee048-7cc6-4eec-bfce-799cdcd1f24e",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = describe_IR_job_response['InferenceRecommendations'][0]['Metrics']\n",
    "# MaxInvocations is reported per minute; 1550 is the approximate token count per invocation assumed here\n",
    "tokens_per_invocation = 1550\n",
    "token_per_sec = round(metrics['MaxInvocations'] * tokens_per_invocation / 60, 2)\n",
    "cost_per_sec = round(metrics['CostPerHour'] / 3600, 5)\n",
    "cost_per_1k_token = round(cost_per_sec / token_per_sec * 1000, 5)\n",
    "print(\"According to the Inference Recommender job, the corresponding metrics are as follows:\\n\")\n",
    "print(f\"Max tokens per second is about {token_per_sec}\")\n",
    "print(f\"Cost per second is about ${cost_per_sec}\")\n",
    "print(f\"Cost per 1k tokens is about ${cost_per_1k_token}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0eea92a9-2d79-483e-9dc3-ee458a8d2206",
   "metadata": {},
   "source": [
    "Example output:\n",
    "\n",
    "    According to the Inference Recommender job, the corresponding metrics are as follows:\n",
    "    Max tokens per second is about 775.0\n",
    "    Cost per second is about $0.00197\n",
    "    Cost per 1k tokens is about $0.00254"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f69dbef-d858-46ed-a163-0a15c65056ca",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56c6f57-4554-4ea1-ad58-f0a12e294d44",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "session.delete_endpoint(endpoint_name)\n",
    "session.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
